import linearmodels as lm
import numpy as np
import seaborn as sns
import pandas as pd
from linearmodels import OLS
from matplotlib import rcParams
from matplotlib import pyplot as plt
import string
from os import path


class CodeBook:
    """Central registry of the data-frame column names used by this module.

    Each attribute maps a short alias to the actual column label.
    ``filter_data`` resolves keyword names through this class's ``__dict__``,
    so the attribute names double as filter keywords.
    """

    rnd = "subsession.round_number"  # 1, 7, 13 practice + 5 real rounds (MTurk 1 practice, 3 or 4 real)
    # Raw treatment column; generate_features renames it to `treatment`.
    _treatment = "group.name_of_treatment"
    treatment = 'treatment'
    # Position of the treatment name in `treatment_list`.
    num_treatment = "num_treatment"
    woke = "woke_up"
    woke_time = "player.wake_up_time"
    settled = "settled"
    amount_settled = "player.accepted_offer"
    player_id = "participant.code"
    subgroup = 'group.id_in_subsession'  # takes value in 1, 2
    session = 'session_num'
    # Computed in generate_features as (round - 1) % 6.
    experience = 'experience'
    # NOTE(review): generate_features computes int((round - 1) / 6), which
    # starts at 0 for round 1 -- confirm against the range stated below.
    subgame = 'subgame'  # takes value in 1, 2, 3
    time_settled = 'settlement_time'
    rank = 'player.audit_location'
    investigated = 'player.investigated'
    rank_at_settlement = 'rank at settlement'
    rank_at_end = 'rank at end'
    rank_at_wake = 'rank at wake'
    slope = 'group.discount_slope'
    info = 'info'
    priority = 'priority'
    # presumably dummy columns for two specific discount-slope values -- TODO confirm
    mid_discount = 'group.discount_slope_0.413793103448276'
    high_discount = 'group.discount_slope_0.7586206896551719'
    num_settled_at_wake = 'num_settled_at_wake'
    num_settled_at_settlement = 'num_settled_at_settlement'
    diff_rank_delay = 'diff_rank_delay'
    diff_settled_delay = 'diff_settled_delay'
    # settlement time minus wake-up time (see generate_features).
    settlement_delay = 'settlement_delay'
    estimated_rank_at_settlement = 'estimated_rank_at_settlement'
    max_rank_at_settlement = 'max_rank_at_settlement'
    max_rank_at_wake = 'max_rank_at_wake'
    # The columns below are written by EventTime.vars_by_dom_time and
    # generate_features; "ext"/"extended" variants have missing values
    # replaced by max_time.
    ext_time_dom = 'extended_time_dominant'
    time_dom = 'time_dominant'
    wake_up_dom = 'wake_up_dominant'
    dom = 'dominant'
    ext_settle_time = 'extended_settlement_time'
    ext_settle_delay = 'extended_settlement_delay'
    settle_delay_from_dom = 'settlement_delay_from_dominant'
    ext_settle_delay_from_dom = 'extended_settlement_delay_from_dominant'
    
# Short alias for CodeBook, used throughout the module.
CB = CodeBook


# Grouping keys: session, round, subgroup and treatment.
unit_playergroup_rnd = [CB.session, CB.rnd, CB.subgroup, CB.treatment]
# Grouping keys: player and subgame.
unit_player_subgame = [CB.player_id, CB.subgame]


# Treatment names; a name's position in this list is its numeric code.
treatment_list = [
    'random',
    'priority only',
    'aggregate info',
    'targeted info',
    'targeted info high',
]


def to_treatment_name(x):
    """Translate a treatment code into its name from ``treatment_list``.

    A string code is interpreted as a letter (case-insensitive) whose
    position in the lowercase ASCII alphabet is the list index; any other
    value is converted with ``int`` and used as the index directly.
    """
    if isinstance(x, str):
        position = string.ascii_lowercase.index(x.lower())
    else:
        position = x
    return treatment_list[int(position)]


# Round numbers treated as practice rounds.
practice_rounds = [1, 7, 13]


def is_in(f):
    """Turn a filter specification into a one-argument predicate.

    * a callable is returned unchanged,
    * a list yields a membership test (``x in f``),
    * anything else yields an equality test (``x == f``).
    """
    if callable(f):
        return f

    if isinstance(f, list):
        def _is_member(x):
            return x in f
        return _is_member

    def _is_equal(x):
        return x == f
    return _is_equal


def is_less_than(x):
    """Return a predicate that is true for values strictly below ``x``."""
    def _below(y):
        return y < x
    return _below


def is_weakly_less_than(x):
    """Return a predicate that is true for values less than or equal to ``x``."""
    def _at_most(y):
        return y <= x
    return _at_most


def is_greater_than(x):
    """Return a predicate that is true for values strictly above ``x``."""
    def _above(y):
        return y > x
    return _above


def is_weakly_greater_than(x):
    """Return a predicate that is true for values greater than or equal to ``x``."""
    def _at_least(y):
        return y >= x
    return _at_least


def filter_data(data, **kwargs):
    """Return the rows of ``data`` that satisfy every keyword condition.

    Each keyword names a column: it is used as-is when it is one of the
    frame's columns, otherwise it is looked up as an attribute name in the
    ``CodeBook`` class dictionary (a ``KeyError`` results if it is neither).
    Each value is converted to a predicate with ``is_in`` and applied
    element-wise; the conditions are applied one after another, so a row must
    pass all of them.  The input frame is not modified.
    """
    remaining = data.copy()
    for name, condition in kwargs.items():
        if name in remaining.columns:
            column = name
        else:
            column = CB.__dict__[name]
        keep = remaining[column].apply(is_in(condition))
        remaining = remaining.loc[keep]
    return remaining


def pivot_data_to_time_series(data, unit_cols, time_col, feature_cols):
    """Reshape long data into a wide table with one row per time value.

    All index levels of ``data`` are first moved into ordinary columns; the
    frame is then pivoted so that rows are keyed by ``time_col``, columns by
    ``unit_cols``, and cells hold the values of ``feature_cols``.  Cells with
    several matching rows are combined by ``pivot_table``'s default
    aggregation.  The input frame is not modified.
    """
    flat = data.copy()
    index_levels = flat.index.names
    flat = flat.reset_index(level=index_levels)
    wide = flat.pivot_table(
        index=time_col,
        columns=unit_cols,
        values=feature_cols,
    )
    return wide


def unpivot_data(data, time_col):
    """Undo ``pivot_data_to_time_series``: turn a wide table back into long form.

    The table is transposed, the ``time_col`` level is stacked into the row
    index, and the outermost row level (level 0) is unstacked into columns.
    """
    transposed = data.T
    stacked = transposed.stack(time_col)
    return stacked.unstack(0)


def _add_single_lag(data, unit_cols, time_col, feature_cols, lags=1):
    """Append ``l<lags>_<feature>`` columns holding lagged feature values.

    The features are pivoted to a wide table indexed by ``time_col``, shifted
    down by ``lags`` rows (i.e. by positions among the time values present,
    not by time units), converted back to long form and outer-merged onto the
    original rows using ``unit_cols + [time_col]`` as keys.  The result is
    sorted by the unit columns and returned with a fresh integer index.
    """
    keys = unit_cols + [time_col]
    frame = data.copy()

    wide = pivot_data_to_time_series(frame, unit_cols, time_col, feature_cols)
    shifted = unpivot_data(wide.shift(lags), time_col)
    lag_names = {col: f'l{lags}_{col}' for col in feature_cols}
    shifted = shifted.rename(columns=lag_names)

    merged = pd.merge(frame.set_index(keys), shifted, how='outer', on=keys)
    merged = merged.sort_index(level=unit_cols)
    return merged.reset_index()


def add_lags(data, unit_cols, time_col, feature_cols, lags=1):
    """Add lagged copies of ``feature_cols`` for one or several lag lengths.

    Args:
        data: long-format frame containing the unit, time and feature columns.
        unit_cols: list of columns identifying a unit.
        time_col: column holding the time index within a unit.
        feature_cols: list of columns to lag.
        lags: a single integer (Python or numpy) or any iterable of integers
            (list, tuple, range, ...).  Lags are added in iteration order.

    Returns:
        A new frame with one ``l<lag>_<feature>`` column per lag and feature.
        An empty iterable yields an unchanged copy of ``data``.

    Raises:
        TypeError: if ``lags`` is neither an integer nor an iterable of
            integers.  (Previously such input made the function silently
            return ``None``.)
    """
    if isinstance(lags, (int, np.integer)):
        return _add_single_lag(data, unit_cols, time_col, feature_cols, lags)
    # Strings are iterable but never a meaningful collection of lags.
    if isinstance(lags, (str, bytes)) or not hasattr(lags, '__iter__'):
        raise TypeError(
            'lags must be an int or an iterable of ints, got {!r}'.format(lags))
    this_data = data.copy()
    for lag in lags:
        this_data = _add_single_lag(
            this_data, unit_cols, time_col, feature_cols, lags=lag)
    return this_data


def get_subgroup_stats(data, unit_id, feature, stat=np.mean):
    """Apply ``stat`` to the ``feature`` column(s) of every ``unit_id`` group.

    ``feature`` may be a single column label or a list of labels; a single
    label is wrapped in a list so that ``stat`` always receives a DataFrame.
    The result is whatever ``groupby(...).apply(stat)`` returns, indexed by
    the grouping keys.
    """
    columns = feature if isinstance(feature, list) else [feature]
    grouped = data.copy().groupby(unit_id)
    return grouped[columns].apply(stat)


def add_subgroup_stat(data, unit_id, feature, name, stat=np.mean):
    """Attach a per-group statistic of ``feature`` to every row as ``name``.

    The statistic is computed with ``get_subgroup_stats``, its ``feature``
    column is renamed to ``name``, and the result is outer-merged back onto a
    copy of ``data`` using ``unit_id`` as the key.
    """
    base = data.copy()
    group_stats = get_subgroup_stats(data, unit_id, feature, stat)
    renamed = group_stats.rename(columns={feature: name})
    return base.merge(renamed, how='outer', on=unit_id)


def numeric_code_treatment(x):
    """Return the position of treatment name ``x`` in ``treatment_list``.

    ``list.index`` raises ``ValueError`` when ``x`` is not a listed name.
    """
    code = treatment_list.index(x)
    return code


class EventTime:
    """Per-player statistics evaluated at a player-specific point in time.

    A *time function* maps a player's row to a time; a *stat function* takes
    ``(group_frame_indexed_by_player, t, player_id)`` and returns the
    statistic at time ``t``.  All methods are classmethods; the class holds no
    state.
    """

    # Columns (in addition to the grouping keys) kept when splitting into groups.
    _ET_COLS = [CB.rank, CB.woke_time, CB.time_settled, CB.player_id]

    @classmethod
    def wake_time(cls, x):
        """Time function: the wake-up time stored in row ``x``."""
        return x[CB.woke_time]

    @classmethod
    def settle_time(cls, x):
        """Time function: the settlement time stored in row ``x``."""
        return x[CB.time_settled]

    @classmethod
    def end_time_mturk_noq(cls, x):
        """Time function returning the constant 45; ``x`` is ignored."""
        return 45

    @classmethod
    def stat_at_time(cls, data, time_func,  player_id, unit_id=None,
                     stat_func=None):
        """Evaluate ``stat_func`` for one player at the time given by ``time_func``.

        Args:
            data: frame containing at least the player-id column and whatever
                columns the time and stat functions read.
            time_func: callable mapping the player's row to a time.
            player_id: value of the player-id column identifying the player.
            unit_id: optional sequence of values, aligned with
                ``unit_playergroup_rnd``, used to restrict ``data`` first.
            stat_func: statistic to evaluate; defaults to ``rank_stat`` (the
                same default as ``add_stat_at_time``).

        Returns:
            ``np.nan`` when the time is missing, otherwise the statistic.
        """
        # Without this default an omitted stat_func would be called as None.
        stat_func = cls.rank_stat if stat_func is None else stat_func
        this_data = data.copy()
        if unit_id is not None:
            d = dict(zip(unit_playergroup_rnd, [[u] for u in unit_id]))
            this_data = filter_data(this_data, **d)
        this_data = this_data.set_index(CB.player_id)
        t = time_func(this_data.loc[player_id])
        if pd.isnull(t):
            return np.nan
        return stat_func(this_data, t, player_id)

    @classmethod
    def rank_stat(cls, this_data, t, player_id):
        """Player's rank minus the number of lower-ranked rows settled before ``t``.

        ``this_data`` must be indexed by player id.  Rows count when both their
        settlement time is strictly below ``t`` and their rank is strictly
        below the player's own rank.
        """
        r = this_data.loc[player_id, CB.rank]
        n = filter_data(
            this_data, time_settled=is_less_than(t), rank=is_less_than(r)
        ).shape[0]
        return r - n

    @classmethod
    def num_settled_stat(cls, this_data, t, player_id):
        """Number of rows whose settlement time is strictly below ``t``.

        ``player_id`` is unused; it is accepted to match the stat-function
        signature.
        """
        return filter_data(
            this_data, time_settled=is_less_than(t)
        ).shape[0]

    @classmethod
    def add_stat_at_time(cls, data, time_func, name=CB.rank_at_settlement,
                         stat_func=None):
        """Return a copy of ``data`` with column ``name`` holding the statistic.

        The frame is split by ``unit_playergroup_rnd``; within each group the
        statistic is evaluated for every player at that player's own time.

        Args:
            data: frame containing the grouping keys and ``_ET_COLS``.
            time_func: callable mapping a player's row to a time.
            name: label of the column to write.
            stat_func: statistic to evaluate; defaults to ``rank_stat``.
        """
        stat_func = cls.rank_stat if stat_func is None else stat_func
        this_data = data.copy()
        # Float placeholder: the statistic can be NaN, and writing NaN into an
        # integer column is an incompatible-dtype assignment in recent pandas.
        # Rows the groupby never visits keep this placeholder value of 0.
        this_data[name] = 0.0
        gb = this_data[unit_playergroup_rnd + cls._ET_COLS].groupby(
            unit_playergroup_rnd)
        for _, g in gb:
            # Assign straight into the result instead of first writing into
            # the group slice.
            values = g[CB.player_id].apply(
                cls._get_stat_from_group(g, time_func, stat_func))
            this_data.loc[g.index, name] = values
        return this_data

    add_rank_at_time = add_stat_at_time

    @classmethod
    def _get_stat_from_group(cls, g, time_func, stat_func):
        """Build a ``player_id -> statistic`` function bound to group ``g``."""
        def _get_stat(player_id):
            return cls.stat_at_time(
                g, time_func, player_id, stat_func=stat_func)
        return _get_stat

    @classmethod
    def vars_by_dom_time(cls, data, max_time):
        """Add per-time-bin and dominance-time columns for times ``0..max_time-1``.

        For every integer ``j`` in ``range(max_time)`` this adds
        ``rank_at_<j>``, ``woke_by_bin_<j>``, ``settlement_time_by_bin_<j>`` and
        ``settlement_time_in_bin_<j>``.  It then derives the dominance columns
        (``CB.dom``, ``CB.time_dom``, ``CB.wake_up_dom``, ``CB.ext_time_dom``)
        and the settlement delays measured from the dominance time.  A row is
        flagged dominant once its rank at some ``j`` is below 1.5.

        The input frame is never modified; a new frame is returned.
        """
        # Work on a copy so the caller's frame is untouched even when the
        # loop below does not run (max_time == 0).
        data = data.copy()
        for j in range(0, max_time):
            # Bind j as a default so the function does not depend on the
            # loop variable's later values.
            def time_rank(x, j=j):
                return j
            data = cls.add_rank_at_time(
                data, time_rank, name='rank_at_' + str(j))
            data['woke_by_bin_' + str(j)] = data[CB.woke_time] <= j
            data['settlement_time_by_bin_' + str(j)] = (
                data[CB.time_settled] <= j
            )
            data['settlement_time_in_bin_' + str(j)] = (
                (data[CB.time_settled] > j)
                &
                (data[CB.time_settled] <= (j + 1))
            )
        data[CB.time_dom] = np.nan
        data[CB.dom] = False
        data[CB.wake_up_dom] = False
        for j in range(0, max_time):
            data.loc[data['rank_at_' + str(j)] < 1.5, CB.dom] = True
            # Record only the first j at which the row became dominant.
            data.loc[
                data[CB.dom] & data[CB.time_dom].isna(),
                CB.time_dom] = j
        data.loc[data[CB.woke_time] >= data[CB.time_dom], CB.wake_up_dom] = True
        # NOTE(review): max(axis=1) skips NaN, so a row that never became
        # dominant but has a wake-up time ends up with time_dom equal to its
        # wake-up time -- confirm this is intended.
        data[CB.time_dom] = data[[CB.time_dom, CB.woke_time]].max(axis=1)
        # NOTE(review): this writes NaN into a boolean column, which changes
        # the column's dtype whenever at least one row matches.
        data.loc[np.isnan(data[CB.time_dom]), CB.wake_up_dom] = np.nan
        data[CB.ext_time_dom] = data[CB.time_dom].fillna(max_time)
        data[CB.settle_delay_from_dom] = (
            data[CB.time_settled]
            - data[[CB.time_dom, CB.woke_time]].max(axis=1)
        )
        data[CB.ext_settle_delay_from_dom] = (
            data[CB.settle_delay_from_dom].fillna(max_time)
        )
        return data

# Short alias for EventTime.
ET = EventTime


# data setup
# NOTE: this read runs at import time and uses a relative path.
df_mturk_noq = pd.read_csv('data/cleaned.csv')

def generate_features(data, domtime_vars=0, max_time=ET.end_time_mturk_noq(0)):
    """Build the analysis features from the cleaned experiment data.

    Steps, in order:
      1. drop rows whose rank column equals 0;
      2. rename the raw treatment column and map its codes to names, then add
         the numeric treatment code;
      3. replace missing settlement amounts by 0;
      4. derive ``subgame`` as ``int((round - 1) / 6)`` and ``experience`` as
         ``(round - 1) % 6``;
      5. add the group mean of ``settled`` and the 1- and 2-period lags of the
         own and group-mean settlement indicators, plus their averages;
      6. add the rank at settlement and the rank at wake-up;
      7. optionally add the dominance-time variables;
      8. add the settlement delay and the ``extended`` variants in which
         missing values are replaced by ``max_time``.

    Args:
        data: raw frame as read from the cleaned CSV.
        domtime_vars: when equal to 1, also run ``ET.vars_by_dom_time``.
        max_time: fill value for missing times/delays and the upper bound
            for the dominance-time loop.  The default is evaluated once, when
            the function is defined.

    Returns:
        A new frame with the added columns; ``data`` itself is not modified.
    """
    # Copy so the column assignments below operate on an independent frame
    # rather than on the result of a selection.
    data = data.loc[data[CB.rank] != 0, :].copy()
    data = data.rename(columns={CB._treatment: CB.treatment})
    # Plain column assignment replaces the column outright, which stays valid
    # when numeric treatment codes are replaced by their string names.
    data[CB.treatment] = data[CB.treatment].apply(to_treatment_name)
    data[CB.amount_settled] = data[CB.amount_settled].fillna(0)
    data[CB.num_treatment] = data[CB.treatment].apply(
        numeric_code_treatment)
    data[CB.subgame] = data[CB.rnd].apply(lambda x: int((x - 1) / 6))
    data[CB.experience] = data[CB.rnd].apply(lambda x: (x - 1) % 6)

    data = add_subgroup_stat(
        data, unit_playergroup_rnd, CB.settled, 'mean_settled')
    lagged_features = [CB.settled, 'mean_settled']
    data = add_lags(
        data, unit_player_subgame, CB.experience, lagged_features, lags=1)
    data = add_lags(
        data, unit_player_subgame, CB.experience, lagged_features, lags=2)
    data['recent_other_settled'] = data[
        ['l1_mean_settled', 'l2_mean_settled']].mean(axis=1)
    data['recent_own_settled'] = data[
        ['l1_settled', 'l2_settled']].mean(axis=1)

    data = ET.add_rank_at_time(data, ET.settle_time)
    data = ET.add_rank_at_time(data, ET.wake_time, CB.rank_at_wake)
    if domtime_vars == 1:
        data = ET.vars_by_dom_time(data, max_time=max_time)

    data[CB.ext_settle_time] = data[CB.time_settled].fillna(max_time)
    data[CB.settlement_delay] = data[CB.time_settled] - data[CB.woke_time]
    data[CB.ext_settle_delay] = data[CB.settlement_delay].fillna(max_time)
    return data

# Replace the raw frame by its feature-augmented version (runs at import time).
df_mturk_noq = generate_features(df_mturk_noq)
